import copy
import glob
import os
import time
from collections import deque
import pickle

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from a2c_ppo_acktr import algo, utils
from a2c_ppo_acktr.envs import make_vec_envs
from a2c_ppo_acktr.model import Policy
from a2c_ppo_acktr.algo.gail import ExpertDataset
from a2c_ppo_acktr.storage import RolloutStorage
from a2c_ppo_acktr.arguments import get_args
from evaluation import evaluate
from CNF import *
from datetime import datetime
import wandb

# Parse the command-line configuration once; everything below reads from it.
args = get_args()

# Bind `odeint` to the plain solver unless the adjoint variant was requested.
if not args.adjoint:
    from torchdiffeq import odeint
else:
    from torchdiffeq import odeint_adjoint as odeint

# Labels attached to the Weights & Biases run.
tags = [
    'NeuralODE',
    f'Game_{args.env_name}with{args.num_demo}trajs',
]
if args.log_wandb:
    wandb.init(
        name=f"({args.env_name},{args.num_demo})",
        project=f"CNF_{args.env_name}",
        tags=tags,
    )
    # Record the parsed arguments as the run configuration.
    wandb.config.update(args)

# Give each run its own timestamped output directory:
#   <save_dir>/CNF/<env_name>/<num_demo>/<YYYYmmdd-HHMMSS>
datetime_now = datetime.now().strftime("%Y%m%d-%H%M%S")
run_dir_parts = ("CNF", args.env_name, str(args.num_demo), str(datetime_now))
args.save_dir = os.path.join(args.save_dir, *run_dir_parts)

# Load the expert demonstrations for this environment.
file_name = os.path.join(args.experts_dir, "{}.h5".format(args.env_name))
expert_dataset = ExpertDataset(file_name, num_trajectories=args.num_demo, subsample_frequency=args.subsample_frequency)

# Keep the final partial batch so every expert sample is used.
drop_last = False
train_loader = torch.utils.data.DataLoader(dataset=expert_dataset, batch_size=args.batch_size,
                                           shuffle=True, drop_last=drop_last)

# Per-feature minimum and maximum of the expert states over the whole dataset,
# used later to min-max scale the states before they are fed to the flow.
# Both stay None if the loader yields no batches.
minvalues = None
maxvalues = None
for states, _actions, _seqs in train_loader:
    batch_min = torch.min(states, dim=0).values
    batch_max = torch.max(states, dim=0).values
    if minvalues is None:
        # First batch seeds the running extrema.
        minvalues, maxvalues = batch_min, batch_max
    else:
        minvalues = torch.min(minvalues, batch_min)
        # The running maximum must be folded with the previous maximum
        # (not the minimum), otherwise earlier batches' maxima are lost.
        maxvalues = torch.max(maxvalues, batch_max)

# Endpoints of the integration interval handed to the ODE solver.
t0, t1 = 0, 10

# Use the GPU selected by args.gpu when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{args.gpu}')
else:
    device = torch.device('cpu')

# model
# The vectorised environments are created only to read the observation
# dimensionality; they are closed again once the model objects exist.
envs = make_vec_envs(args.env_name, args.seed, args.num_processes, args.gamma, args.log_dir, device, True)
obs_dim = envs.observation_space.shape[0]

func = CNF(in_out_dim=obs_dim, hidden_dim=args.hidden_dim, width=args.width)
func = func.to(device)

# Base distribution of the flow: zero mean, identity covariance, built from
# NumPy arrays (so the parameters carry NumPy's default float64 dtype).
base_mean = torch.tensor(np.zeros(obs_dim)).to(device)
base_cov = torch.tensor(np.identity(obs_dim)).to(device)
p_z0 = torch.distributions.MultivariateNormal(loc=base_mean, covariance_matrix=base_cov)

func_optimizer = optim.Adam(func.parameters(), lr=args.lr)
envs.close()

# Running average of the training loss, used for logging.
loss_meter = RunningAverageMeter()

if args.save_dir is not None:
    # Create the run directory if needed. exist_ok avoids both the racy
    # "check then create" pattern and a FileExistsError when the directory
    # is already present.
    os.makedirs(args.save_dir, exist_ok=True)
    ckpt_path = os.path.join(args.save_dir, 'ckpt.pth')
    if os.path.exists(ckpt_path):
        # Resume model and optimizer state. map_location remaps the stored
        # tensors onto the current device, so a checkpoint written on a GPU
        # can still be loaded on a CPU-only (or differently numbered) host.
        checkpoint = torch.load(ckpt_path, map_location=device)
        func.load_state_dict(checkpoint['func_state_dict'])
        func_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print('Loaded ckpt from {}'.format(ckpt_path))
# Pre-set so the summary print at the bottom works even when args.niters < 1
# and the outer loop never binds `itr`.
itr = 0
# The integration times are identical for every batch: solve from t1 back to t0.
integration_times = torch.tensor([t1, t0]).type(torch.float32).to(device)
try:
    for itr in range(1, args.niters + 1):

        # Train func to get a good estimation of P_E(s_t)
        print('Training P_E(s_t) ... ')
        for i, (states, actions, seqs) in enumerate(train_loader):
            # Example shapes noted by the original author (presumably for
            # i / states / actions / seqs -- TODO confirm):
            #   0 torch.Size([128, 4]) torch.Size([128, 1]) torch.Size([128])
            #   torch.Size([100, 1]), torch.Size([100, 1])
            func_optimizer.zero_grad()

            # Min-max scale each state feature with the dataset-wide extrema;
            # the small epsilon avoids division by zero for constant features.
            states = (states-minvalues)/(maxvalues-minvalues+1e-8)

            # Initial condition of the ODE: the data point and a zero
            # log-density change accumulator (one entry per sample).
            x = states.to(device)
            logp_diff_t1 = torch.zeros(states.size()[0], 1).type(torch.float32).to(device)

            z_t, logp_diff_t = odeint(    # the ODEsolver
                func,
                (x, logp_diff_t1),
                integration_times,
                atol=1e-5,
                rtol=1e-5,
                method='dopri5',
            )
            # From FFJORD (time from t_0 to t_1):
            #   z_0               = x + \int f dt
            #   logp_x - logp_z0  = 0 + \int tr dt   (trace term)
            # The last entry along the time axis is the solution at t0.
            z_t0, logp_diff_t0 = z_t[-1], logp_diff_t[-1]

            # Log-likelihood of the batch under the flow; the loss is its
            # negative mean.
            logp_x = p_z0.log_prob(z_t0).view(-1).to(device) - logp_diff_t0.view(-1)
            loss = - logp_x.mean(0)

            loss.backward()
            func_optimizer.step()

            loss_meter.update(loss.item())

            if args.log_wandb:
                wandb.log({"CNF loss": loss_meter.avg,
                           "epoch": itr})
            print('---CNF Iter: {}, running avg loss: {:.4f}'.format(itr, loss_meter.avg))

            # Snapshot model and optimizer after every batch, keyed by
            # (outer iteration, batch index).
            ckpt_path = os.path.join(args.save_dir, '{}_{}_ckpt.pth'.format(itr, i))
            torch.save({
                'func_state_dict': func.state_dict(),
                'optimizer_state_dict': func_optimizer.state_dict(),
            }, ckpt_path)
            print('Stored ckpt at {}'.format(ckpt_path))


except KeyboardInterrupt:
    # On Ctrl-C, persist the latest state under the name the resume logic
    # looks for ('ckpt.pth') before falling through to the summary print.
    if args.save_dir is not None:
        ckpt_path = os.path.join(args.save_dir, 'ckpt.pth')
        torch.save({
            'func_state_dict': func.state_dict(),
            'optimizer_state_dict': func_optimizer.state_dict(),
        }, ckpt_path)
        print('Stored ckpt at {}'.format(ckpt_path))
print('Training complete after {} iters.'.format(itr))
